import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

import argparse
import time
import warnings

import sys
sys.path.append('../code')
from classification import *

from pathlib import Path
import pickle

def parse_commandline():
    """Read the command-line options and return the parsed namespace.

    Recognised options:
        --params VALUE [VALUE ...]
            One or more string values, stored as a list on ``.params``.
            When the option is omitted, ``.params`` is ``None``.

    The ``__main__`` block of this script interprets the first value as the
    integer random seed.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--params", default=None, nargs='+')
    return cli.parse_args()

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
if __name__ == '__main__':
    # Suppress every warning category for the whole run (this includes any
    # warnings raised by scikit-learn or pandas) to keep console output clean.
    warnings.simplefilter('ignore')

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # The first --params value is the random seed. It seeds numpy's global
    # RNG, is passed to classification_onefold, and tags every output file.
    args = parse_commandline()
    if not args.params:
        # parse_commandline() leaves .params as None when --params is
        # omitted; fail with a clear message instead of a TypeError.
        sys.exit('error: --params SEED is required (e.g. --params 0)')
    seed = int(args.params[0])
    np.random.seed(seed)

    # Number of bootstrap resamples used for the p-value computation.
    n_bs = 10000

    # Whether to balance the train-set size between the groups:
    #   1.0  -> perfectly balanced
    #   None -> no balancing
    ratio = None
    # Presumably the held-out test fraction; passed to classification_onefold.
    test_size = 0.33
    # Passed through to bs_p_value_same_optimal -- TODO confirm its meaning.
    delta = 0.01

    # Output location. Create it up front so an invalid or unwritable path
    # fails before model fitting and the bootstrap, not after them.
    out_dir = Path('/home/kop5674/FA_frontier/result/science')
    out_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    df = pd.read_csv('../data/df_health_risk.csv')
    Y_column = 'gagne_sum_t'
    G_column = 'dem_race_black'
    Y_columns = ['cost_t', 'gagne_sum_t']
    # Features: every column that is neither an outcome nor the group
    # attribute, sorted for a deterministic column order.
    X_column = sorted(set(df.columns) - set(Y_columns) - {G_column})

    # Binarize the outcome in place: 1 for rows strictly above the
    # `thres_percentile`-th percentile of Y_column, otherwise 0. Both the
    # threshold and the labels come from the same array `y`.
    y = df[Y_column].values
    thres_percentile = 97
    y_thres = np.percentile(y, thres_percentile)
    df[Y_column] = (y > y_thres) * 1

    # ------------------------------------------------------------------
    # Train models
    # ------------------------------------------------------------------
    # Two logistic-regression models; presumably clf_b is fit on the "blue"
    # group and clf_r on the "red" group -- confirm in classification_onefold.
    clf_b = LogisticRegression(max_iter=1000)
    clf_r = LogisticRegression(max_iter=1000)

    fold_res, X_test_blue, y_test_blue, X_test_red, y_test_red = \
        classification_onefold(df, Y_column, G_column, X_column,
                               clf_b, clf_r, seed=seed, ratio=ratio,
                               verbose=True, test_size=test_size,
                               standardize=False)

    # Fitted models and per-group metrics returned by the single fold.
    clf_b = fold_res['clf_b']
    clf_r = fold_res['clf_r']
    e_b_B = fold_res['res_blue']['e_b']
    e_r_B = fold_res['res_blue']['e_r']
    e_b_R = fold_res['res_red']['e_b']
    e_r_R = fold_res['res_red']['e_r']

    # ------------------------------------------------------------------
    # Bootstrap p-values
    # ------------------------------------------------------------------
    # NOTE(review): the bootstrap is seeded with 42 regardless of the CLI
    # seed -- confirm this is intentional.
    p_b, p_r, e_b_B_list, e_r_B_list, e_b_R_list, e_r_R_list = \
        bs_p_value_same_optimal(clf_b, clf_r, X_test_blue, y_test_blue,
                                X_test_red, y_test_red,
                                e_b_B, e_r_B, e_b_R, e_r_R,
                                delta=delta, n_bs=n_bs, seed=42)

    # ------------------------------------------------------------------
    # Save results
    # ------------------------------------------------------------------
    # One .npy file per bootstrap list; names are
    # test_sameopt_logreg_<key>_list_<seed>_<n_bs>.npy.
    bs_lists = {
        'e_b_B': e_b_B_list,
        'e_r_B': e_r_B_list,
        'e_b_R': e_b_R_list,
        'e_r_R': e_r_R_list,
    }
    for key, values in bs_lists.items():
        npy_path = out_dir / f'test_sameopt_logreg_{key}_list_{seed}_{n_bs}.npy'
        np.save(npy_path, values)

    # Summary of p-values and point estimates, printed and pickled.
    filename = f'test_sameopt_logreg_{delta}_{seed}_{ratio}.pkl'
    summary = {'p_b': p_b, 'p_r': p_r, 'e_b_B': e_b_B, 'e_r_B': e_r_B,
               'e_b_R': e_b_R, 'e_r_R': e_r_R}

    print(summary)

    with open(out_dir / filename, 'wb') as f:
        pickle.dump(summary, f)